import torch
import torch.nn as nn
import sys
import os
parent_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.append(parent_folder)
from Record.file_management import read_obj_dumps, strip_instance
import argparse
from Vae.data_utils.data_util import _CustomDataParallel
from Vae.models.cnn_vae import CNNVAE
import numpy as np
import PIL.Image as Image

from Vae.data_utils.vae_dataset import ImageDataset
import glob
from tqdm import tqdm
import pickle

CHUNK_SIZE = 500000


def record_vae_state(args):
    """Encode rollout frames with a trained CNN VAE and pickle the latent means.

    Loads a ``CNNVAE`` checkpoint (wrapped in ``_CustomDataParallel``) and runs
    it over every ``state*_*.png`` found one directory below
    ``args.load_rollouts``. For each sample it stores the encoder mean returned
    by the model.

    Args:
        args: namespace providing ``latent_dim``, ``frame_stack``, ``use_flow``,
            ``fit_linear``, ``checkpoint_path``, ``load_rollouts`` and
            ``save_rollouts``.

    Side effects:
        Prints the number of image files and the number of ``CHUNK_SIZE``
        chunks they span, then the number of time steps that received at least
        one encoding. If ``args.save_rollouts`` is non-empty, it creates that
        directory if needed and pickles the encodings to
        ``encodings_stack{frame_stack}_vae_z{latent_dim}.pkl`` inside it. The
        pickle is a list with one dict per entry of the object dump, mapping
        object name to a numpy array.
    """
    device = 'cuda'

    # With optical flow the model input has a fixed 5 channels; otherwise it is
    # 3 channels per stacked frame.
    in_channels = 5 if args.use_flow else 3 * args.frame_stack
    model = CNNVAE(latent_dim=args.latent_dim, nc=in_channels, fit_linear=args.fit_linear).to(device)
    # Wrapped before loading the state dict; presumably the checkpoint was saved
    # from the wrapped model so its keys match the wrapper — TODO confirm.
    model = _CustomDataParallel(model)
    model.load_state_dict(torch.load(args.checkpoint_path, map_location=device))
    model.eval()

    obj_data = read_obj_dumps(args.load_rollouts, i=0, rng=-1, filename='object_dumps.txt')
    file_paths = glob.glob(os.path.join(args.load_rollouts, '*', 'state*_*.png'))
    print(len(file_paths), int(np.ceil(len(file_paths) / CHUNK_SIZE)))

    dataset = ImageDataset(file_paths, obj_data, frame_stack=args.frame_stack, split='full',
                           ret_frame_info=True, use_flow=args.use_flow, fit_linear=args.fit_linear)
    # One dict per time step: {object name: encoding}. Assumes every
    # 'frame_number' from the dataset is a valid index into obj_data — TODO confirm.
    encodings = [{} for _ in range(len(obj_data))]
    # Reduced number of workers because of memory issues.
    data_loader = torch.utils.data.DataLoader(dataset, batch_size=2048, shuffle=False,
                                              num_workers=4, drop_last=False)

    with torch.no_grad():
        for data in tqdm(data_loader, desc="Encoding"):
            x = data['image'].to(device)
            _, mean, _ = model(x)
            # NOTE(review): an alternative encoder path existed here:
            # mean = model.forward_fit_linear(x)

            # One device->host copy per batch rather than one per sample.
            batch_means = mean.cpu().numpy()
            frame_idxs = data['frame_number'].tolist()
            for frame_idx, obj_name, latent in zip(frame_idxs, data['obj_name'], batch_means):
                encodings[frame_idx][obj_name] = latent

    non_empty_encodings = sum(len(step) > 0 for step in encodings)
    print("nonempty encodings: ", non_empty_encodings)

    # Save encodings; the filename reflects the actual stack size and latent dim.
    if args.save_rollouts:
        os.makedirs(args.save_rollouts, exist_ok=True)
        out_name = f'encodings_stack{args.frame_stack}_vae_z{args.latent_dim}.pkl'
        with open(os.path.join(args.save_rollouts, out_name), 'wb') as f:
            pickle.dump(encodings, f)




if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description='Encode rollout frames with a trained CNN VAE and pickle the latent means.')
    # Defaults are the values that used to be hard-coded after parse_args(), so a
    # run without arguments behaves exactly as before. The flags now actually
    # override them (previously they were silently overwritten).
    parser.add_argument('--load_rollouts', default="datasets/box2d_default",
                        help='where to load the rollouts from')
    parser.add_argument('--checkpoint_path',
                        default="data/20240412-223520_default_cnn_vae_frame5_z128_full_data/model_checkpoint_49.pth",
                        help='where to load the model from')
    parser.add_argument('--save_rollouts', default="datasets/box2d_default",
                        help='where to save the rollouts to')
    parser.add_argument('--latent_dim', type=int, default=128,
                        help='latent dimension of the VAE checkpoint')
    parser.add_argument('--frame_stack', type=int, default=5,
                        help='number of stacked frames per model input')
    parser.add_argument('--use_flow', action='store_true',
                        help='model takes the 5-channel optical-flow input')
    parser.add_argument('--fit_linear', action='store_true',
                        help='construct the model with fit_linear=True')
    args = parser.parse_args()

    # Other settings used previously (pass via the flags above):
    #   --load_rollouts /hdd/datasets/object_data/box2d/default/
    #   --checkpoint_path /hdd/datasets/object_data/box2d/nets/best_model_ctrl_poly.pth
    #   --checkpoint_path data/20240111-200348_default_cnn_vae_z10_full_dataset/model_checkpoint_20.pth
    #   --checkpoint_path data/20240118-133559_default_cnn_vae_withflow_z10_full_dataset/best_model.pth
    #   --checkpoint_path data/20240207-154656_default_cnn_fitlinear_withflow_z10_full_dataset/model_checkpoint_1.pth
    #   --checkpoint_path data/20240226-145016_default_cnn_fitlinear_frame5_z10_full_dataset/model_checkpoint_39.pth
    #   --checkpoint_path data/20240226-163253_default_cnn_vae_frame5_z10_folder_1/model_checkpoint_950.pth
    #   --checkpoint_path data/20240226-222823_default_cnn_vae_frame5_z128_folder_1/model_checkpoint_950.pth
    record_vae_state(args)